feat(dpa4): use edge force and atomic virial #5518
Pull request overview
This PR updates the PyTorch SeZM/DPA4 implementation to compute forces and virials via an edge-based gradient (“edge-force scatter”) rather than differentiating w.r.t. coordinates, and adjusts the surrounding compile/export/test infrastructure accordingly. This aligns the SeZM path with an edge-centric internal representation and adds validation tests for the new force/virial assembly.
Changes:
- Added `edge_energy_deriv()` to assemble extended force, global virial, and per-atom virial by scattering gradients taken w.r.t. per-edge displacement vectors.
- Updated `SeZMModel` to build both local and extended edge index spaces, detach edge vectors as the autograd leaf, and route ZBL bridging through the edge-form `InterPotential`.
- Expanded/updated PT tests to use edge-based `InterPotential` inputs and to finite-difference validate force/virial/atom-virial consistency; adjusted compile-cache key expectations.
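The edge-force scatter described in the list above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the PR's `edge_energy_deriv()`: the toy pair energy, the analytic edge gradient (standing in for autograd on a detached edge-vector leaf), the virial sign convention, and the even split of each edge's virial between its two atoms are all assumptions made for the example.

```python
import math

def pair_energy(r):
    # Toy pair energy phi(r) = (r - 1)^2; an assumption for illustration only.
    return (r - 1.0) ** 2

def pair_energy_grad(r):
    return 2.0 * (r - 1.0)

def total_energy(coords, edges):
    return sum(pair_energy(math.dist(coords[i], coords[j])) for i, j in edges)

def edge_force_scatter(coords, edges):
    """Assemble force, global virial and per-atom virial from dE/d(edge vector)."""
    n = len(coords)
    force = [[0.0] * 3 for _ in range(n)]
    virial = [[0.0] * 3 for _ in range(3)]
    atom_virial = [[[0.0] * 3 for _ in range(3)] for _ in range(n)]
    for i, j in edges:
        r_vec = [coords[j][k] - coords[i][k] for k in range(3)]
        r = math.sqrt(sum(c * c for c in r_vec))
        # g = dE/d(r_vec): analytic here, where the PR uses autograd.
        g = [pair_energy_grad(r) * c / r for c in r_vec]
        for a in range(3):
            # r_vec = x_j - x_i, so dE/dx_j = +g and dE/dx_i = -g.
            force[j][a] -= g[a]
            force[i][a] += g[a]
            for b in range(3):
                # Assumed convention: W = -sum_e r_e (outer) g_e.
                w = -r_vec[a] * g[b]
                virial[a][b] += w
                # Assumed split: half of each edge term to each endpoint.
                atom_virial[i][a][b] += 0.5 * w
                atom_virial[j][a][b] += 0.5 * w
    return force, virial, atom_virial

def fd_force(coords, edges, atom, axis, h=1e-5):
    """Central finite-difference force on one coordinate, for validation."""
    plus = [list(x) for x in coords]
    minus = [list(x) for x in coords]
    plus[atom][axis] += h
    minus[atom][axis] -= h
    return -(total_energy(plus, edges) - total_energy(minus, edges)) / (2 * h)

coords = [(0.0, 0.0, 0.0), (1.2, 0.0, 0.0), (0.3, 0.9, 0.4)]
edges = [(0, 1), (0, 2), (1, 2)]
force, virial, atom_virial = edge_force_scatter(coords, edges)
```

Because every edge adds `+g` to one atom and `-g` to the other, the forces sum to zero by construction, and the per-atom virials sum to the global virial regardless of how each edge term is split.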
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| source/tests/pt/model/test_sezm_spin_model.py | Updates bridging-mask and compile-cache assertions for new edge-based API and cache key shape. |
| source/tests/pt/model/test_sezm_model.py | Migrates InterPotential tests to edge inputs and adds finite-difference validation for edge-force/virial/atomic-virial assembly. |
| deepmd/pt/model/model/transform_output.py | Adds edge_energy_deriv() to compute force/virial/atomic-virial via edge gradients and explicit scatters. |
| deepmd/pt/model/model/sezm_model.py | Switches SeZM force/virial computation to edge-force scatter; updates edge-list builder to return extended indices; converts ZBL to edge form; simplifies compile-cache key. |
| deepmd/pt/model/descriptor/sezm_nn/utils.py | Avoids creating a scalar tensor for eps_sq in safe_norm. |
| deepmd/pt/model/descriptor/sezm_nn/edge_cache.py | Reworks masked-edge canonical-direction padding using F.pad; removes a small scalar tensor allocation. |
| deepmd/pt/entrypoints/freeze_pt2.py | Changes SeZM .pt2 freeze default to export atomic virial and updates docstring accordingly. |
Comments suppressed due to low confidence (1)
deepmd/pt/entrypoints/freeze_pt2.py:484
- Changing `freeze_sezm_to_pt2()` default `atomic_virial` from `False` to `True` changes the default `.pt2` output contract (extra per-atom virial output keys + metadata `do_atomic_virial=true`) for all callers that don’t pass the flag, including `deepmd.pt.entrypoints.main.freeze()`. This is a behavior/API change that can break downstream consumers expecting the previous default output set; consider keeping the default `False` and letting callers opt in explicitly (or plumb a CLI option).
atomic_virial: bool = True,
) -> None:
"""Freeze a SeZM checkpoint into an AOTInductor ``.pt2`` archive.
Parameters
📝 Walkthrough
This PR refactors the SeZM model's force and virial computation pipeline by shifting the differentiation anchor from extended coordinates to an edge-vector autograd leaf, introduces a new
Changes
SeZM edge-based force/virial refactoring
Freeze PT2 atomic_virial default flip
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/pt/model/model/sezm_model.py (1)
2410-2413: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Keep coincident edges in the sparse edge list.
`len_positive = edge_len2 > 1e-10` changes the runtime model, not just the trace sample. The eager edge-cache path keeps valid edges at `r == 0` and relies on `safe_norm` / clamp logic downstream to stay finite, but this filter drops them entirely. That makes the sparse SeZM path disagree with the eager path on overlapping configurations and also suppresses the ZBL bridge exactly where the short-range repulsion should be strongest.
Suggested fix

    - len_positive = edge_len2 > 1e-10
    - edge_mask_actual = valid_flat & src_local_valid & len_positive
    + edge_mask_actual = valid_flat & src_local_valid

If the trace-only clamped sample still needs self-edge sanitization, handle that in the trace-input preparation instead of changing the runtime edge semantics.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/model/model/sezm_model.py` around lines 2410 - 2413, The current filter removes coincident edges by using len_positive = edge_len2 > 1e-10 and applying it to edge_mask_actual, which changes runtime semantics and breaks parity with the eager path; remove the len_positive check from the runtime edge_mask_actual (keep valid_flat & src_local_valid) so edges with r==0 remain in the sparse SeZM edge list, and if trace-only sanitization is required, perform clamping/sanitization in the trace-input preparation logic instead of in the runtime filter; update references around src_local_valid, edge_len2, edge_mask_actual and ensure downstream safe_norm/clamp logic handles zero-distance cases as before.
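The difference between filtering zero-length edges and keeping them behind a regularized norm can be sketched as follows. This is an illustration only: `safe_norm` here is a hypothetical stand-in whose formula is assumed (the real helper in `sezm_nn/utils.py` may differ), and the edge data and mask names are made up for the example.

```python
import math

def safe_norm(vec, eps=1e-12):
    # Hypothetical stand-in: regularized length that stays positive at r == 0.
    return math.sqrt(sum(c * c for c in vec) + eps * eps)

# Three candidate edges: a coincident pair, a normal edge, and a padding slot.
edge_vecs = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0)]
valid = [True, True, False]

# Filtering on squared length silently drops the coincident edge.
mask_filtered = [
    v and sum(c * c for c in e) > 1e-10 for e, v in zip(edge_vecs, valid)
]

# Keeping only the validity mask retains it, leaving finiteness to safe_norm.
mask_kept = list(valid)
kept_edges = [e for e, m in zip(edge_vecs, mask_kept) if m]
lengths = [safe_norm(e) for e in kept_edges]
directions = [[c / n for c in e] for e, n in zip(kept_edges, lengths)]
```

With the length filter the `r == 0` edge disappears from the edge list, so any short-range term evaluated per edge never sees it; with the validity mask alone it stays in the list and its length and direction remain finite.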
🧹 Nitpick comments (1)
source/tests/pt/model/test_sezm_model.py (1)
1218-1230: ⚡ Quick win
Add a direct ZBL virial finite-difference check.
The `bridging_method="ZBL"` branch is only validated by `atom_virial.sum(dim=1) == virial`. If the bridged virial is scattered with the same sign/indexing bug into both tensors, this still passes. Reusing the strain finite-difference check for `ZBL` would pin the new bridged virial path itself.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/pt/model/test_sezm_model.py` around lines 1218 - 1230, In test_atom_virial_sums_to_global_virial, add a direct finite-difference check for the "ZBL" branch rather than only asserting atom_virial.sum == virial; after building the model with _build_model(bridging_method="ZBL") and calling model(coord, atype, box=box, do_atomic_virial=True), compute the global virial via the existing numerical finite-difference helper used for strain checks (reuse the same FD routine used elsewhere in the test suite) and assert that the model's out["virial"] matches that FD-computed virial within tolerances — this pins the bridged virial path itself and prevents sign/index scatter from passing the test.
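A strain finite-difference check of the kind suggested above can be sketched in plain Python. This is not the test suite's helper: the toy pair energy, the non-periodic setup (a periodic test would also strain the box), and the sign convention W = -dE/d(strain) are assumptions chosen for illustration.

```python
import math

def energy(coords, edges):
    # Toy pair energy exp(-r) summed over edges; an assumption for the sketch.
    return sum(math.exp(-math.dist(coords[i], coords[j])) for i, j in edges)

def edge_virial(coords, edges):
    """Global virial W = -sum_e r_e (outer) dE/dr_e, assembled edge by edge."""
    W = [[0.0] * 3 for _ in range(3)]
    for i, j in edges:
        r_vec = [coords[j][k] - coords[i][k] for k in range(3)]
        r = math.sqrt(sum(c * c for c in r_vec))
        g = [-math.exp(-r) * c / r for c in r_vec]  # dE/d(r_vec)
        for a in range(3):
            for b in range(3):
                W[a][b] -= r_vec[a] * g[b]
    return W

def strained(coords, a, b, h):
    # Apply x' = x (I + eps) with eps[a][b] = h, i.e. x'_b += h * x_a.
    out = [list(x) for x in coords]
    for x_new, x_old in zip(out, coords):
        x_new[b] += h * x_old[a]
    return out

def fd_virial(coords, edges, h=1e-5):
    """Virial from central differences of the energy w.r.t. each strain entry."""
    W = [[0.0] * 3 for _ in range(3)]
    for a in range(3):
        for b in range(3):
            e_plus = energy(strained(coords, a, b, h), edges)
            e_minus = energy(strained(coords, a, b, -h), edges)
            W[a][b] = -(e_plus - e_minus) / (2 * h)
    return W

coords = [(0.0, 0.0, 0.0), (1.1, 0.2, 0.0), (0.4, 0.8, 0.5)]
edges = [(0, 1), (0, 2), (1, 2)]
analytic = edge_virial(coords, edges)
numeric = fd_virial(coords, edges)
```

Unlike a sum-consistency assertion, this comparison fails if the assembled virial has a wrong sign or a transposed index, because the reference comes from the energy alone.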
Codecov Report
✅ All modified and coverable lines are covered by tests.

Coverage Diff

| | master | #5518 | +/- |
|---|---|---|---|
| Coverage | 81.52% | 82.19% | +0.67% |
| Files | 872 | 891 | +19 |
| Lines | 97964 | 101600 | +3636 |
| Branches | 4241 | 4242 | +1 |
| Hits | 79865 | 83511 | +3646 |
| Misses | 16795 | 16786 | -9 |
| Partials | 1304 | 1303 | -1 |
…eepmodeling#5540)

PR-3 (final) of the DPA4/SeZM porting series, pt_expt **inference**: freeze to `.pt2`, Python `DeepEval`, pt→pt_expt checkpoint interop, C++ single-rank, and LAMMPS single-rank. Follows PR-1 (deepmodeling#5515, dpmodel core) and PR-2 (deepmodeling#5522, pt_expt training/export).

## What's included

- **Model freeze to `.pt2`** (`deepmd/pt_expt/model/ener_model.py`, `deepmd/dpmodel/.../ener_model.py`, `deepmd/dpmodel/descriptor/dpa4.py`): register `EnergyModel` under `sezm_ener`/`dpa4_ener` model-type aliases so `BaseModel.deserialize` resolves a standard DPA4 energy model (whose fitting type is `sezm_ener`). Fixed a `torch.export` specialization where `int()` on symbolic shapes baked `nf*nloc` (embedding/so2/attention).
- **NoPBC export fix** (`deepmd/dpmodel/descriptor/dpa4.py`): the `atype_ext[:, :nloc]` slice emitted a spurious `Ne(nall, nloc)` shape guard that crashed the compiled artifact when `nall==nloc` (no ghosts); replaced with `xp_take_first_n` (index_select). NoPBC now matches PBC.
- **pt→pt_expt checkpoint interop** (`deepmd/pt_expt/model/model.py`): `BaseModel.deserialize` unwraps pt's bespoke `SeZMModel` serialization (`type:"SeZM"`, nested `sezm_atomic` atomic model with the pt-only dens head), validates versions, rejects unsupported features (bridging/lora/dens/active_mode) with `NotImplementedError`, and delegates to the standard path.
- **Warn on silently-ignored flags** (`use_amp` descriptor, `enable_tf32` model): warn-once instead of silent drop.

## Tests

- **Model freeze** `source/tests/pt_expt/model/test_dpa4_export.py`: dual-artifact `.pt2`, metadata, AOTI load, artifact-vs-eager parity (1e-10). *(CI-skipped: AOTI is slow; run locally.)*
- **DeepEval parity vs pt** `source/tests/pt_expt/infer/test_dpa4_deep_eval.py`: pt `.pt` vs pt_expt `.pt2` energy/force/global-virial/atom-energy at fp64 1e-10, **PBC and NoPBC**; doubles as the checkpoint-interop proof. Per-atom virial compared by sum (pt's edge-scatter from deepmodeling#5518 redistributes it; global virial matches). *(CI-skipped: AOTI.)*
- **Interop unit tests** `source/tests/pt_expt/model/test_dpa4_interop.py` (CI-runnable, no AOTI): happy-path pt-checkpoint→pt_expt round-trip + every guard branch + version validation + `@variables` filtering.
- **Alias deserialize guard** + **use_amp/enable_tf32 warn-once** tests (CI-runnable).
- **Fixture generator** `source/tests/infer/gen_dpa4.py` (+ wired into `source/install/test_cc_local.sh`).
- **C++ single-rank** `source/api_cc/tests/test_deeppot_dpa4_ptexpt.cc`: 20 tests (double+float), dpa3-matched tolerances. Validated locally.
- **LAMMPS single-rank** `source/lmp/tests/test_lammps_dpa4_pt2.py`: parity + `atom_modify map yes` + the deepmodeling#5450 no-atom-map fail-fast. **Validated on a GPU box (7 passed).**

PR-1 parity suites stay green; the small dpmodel edits are parity-revalidated.

## Known limitations

- **Single-rank only.** Multi-rank/MPI LAMMPS for DPA4 is deferred (no live multi-rank cell; the with-comm artifact compiles but its runtime is not exercised). DPA4 is a message-passing descriptor, so multi-rank follows the existing deepmodeling#5450/deepmodeling#5430 machinery in a later PR.
- **No `.pth` (torch.jit) DPA4**: the pt backend has no `sezm_ener` *model* registration, so `.pth` freeze of a standard DPA4 energy model isn't available; not needed for the pt_expt inference path.
- **Per-atom virial** is not compared element-wise pt-vs-pt_expt (only its global sum); deepmodeling#5518 changed pt's edge-scatter distribution; both are correct, the distribution differs.
- **AOTI tests are CI-skipped** (multi-minute compile): the freeze/DeepEval paths are validated locally, not in CI; the interop/alias/warn tests give CI coverage of the non-AOTI logic.
- **fp64 only**; fp32 freeze untested. CUDA validated at LAMMPS level on a GPU box; the AOTI parity numbers are from CPU fp64.
- **`use_amp`/`enable_tf32`** remain functionally ignored (now warned), by design for this series.
- pt SeZM features out of scope (guarded `NotImplementedError`): spin, ZBL bridging, LoRA, dens/direct-force/denoising heads, SO3 grid projection, GridMLP, SO(2) attention extensions.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

# Release Notes

* **New Features**
  * Enabled DPA4 model inference via the pt_expt backend using dual-artifact compilation.
  * Registered the EnergyModel under additional aliases: `sezm_ener` and `dpa4_ener`.
* **Improvements**
  * Improved dynamic/symbolic shape handling across DPA4 components for export/tracing stability.
  * Enhanced pt SeZM/DPA4 checkpoint deserialization and normalization for interoperability.
  * Added one-time warnings when `use_amp` or `enable_tf32` settings are ineffective.
* **Tests**
  * Added C++ and Python coverage for pt2 inference, LAMMPS integration, model export/freeze, parity, interop, and warning behavior.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
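The remark that per-atom virials are compared only by their sum can be illustrated with a small sketch. It makes no claim about which distribution either backend actually uses: the per-edge values, the `src_share` parameter, and the two splitting rules are hypothetical, and each edge contribution is reduced to a single tensor component for brevity.

```python
def scatter_atom_virial(edge_virials, edges, n_atoms, src_share):
    """Distribute each edge's virial between its two atoms.

    src_share is the fraction given to the first atom of the edge; the
    remainder goes to the second. Both choices conserve the total.
    """
    out = [0.0] * n_atoms
    for w, (i, j) in zip(edge_virials, edges):
        out[i] += src_share * w
        out[j] += (1.0 - src_share) * w
    return out

edges = [(0, 1), (1, 2)]
edge_virials = [1.5, -0.5]  # one component of each per-edge virial (made up)

all_to_first = scatter_atom_virial(edge_virials, edges, 3, src_share=1.0)
half_and_half = scatter_atom_virial(edge_virials, edges, 3, src_share=0.5)
```

The two per-atom lists differ element by element, yet both sum to the same global value, which is why a sum comparison passes where an element-wise comparison would not.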